package com.tsfyun.common.base.support;

import org.apache.commons.io.IOUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Request wrapper that works around the fact that a servlet request body can only be read once.
 *
 * <p>The constructor copies the whole body into memory; {@link #getInputStream()} and
 * {@link #getReader()} then hand out fresh streams over that copy on every call. The wrapper
 * additionally lets callers register extra headers via {@link #addHeader(String, String)} and
 * runs a simple XSS filter over parameter values and over {@link #getHeader(String)}.
 */
public class HttpServletRequestWrapper extends javax.servlet.http.HttpServletRequestWrapper {

    /** Matches {@code eval(...)} calls; compiled once instead of on every filtered value. */
    private static final Pattern EVAL_CALL = Pattern.compile("eval\\((.*)\\)");

    /** Matches a quoted {@code javascript:} URI such as {@code "javascript:..."}. */
    private static final Pattern QUOTED_JAVASCRIPT_URI =
            Pattern.compile("[\\\"\\'][\\s]*javascript:(.*)[\\\"\\']");

    /** The request passed to the constructor. */
    HttpServletRequest orgRequest;

    /** Parameter map captured from the wrapped request at construction time. */
    private final Map<String, String[]> parameterMap;

    /** Cached copy of the request body. */
    private final byte[] bytes;

    /**
     * Extra headers registered through {@link #addHeader(String, String)}.
     * HTTP header names are case-insensitive, so lookups ignore case.
     */
    private final Map<String, String> headers =
            new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);

    /**
     * Wraps the given request and caches its body.
     *
     * @param request the request to wrap
     * @throws IOException if the request body cannot be read
     */
    public HttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        orgRequest = request;
        // The parameter map is captured before the input stream is consumed below.
        parameterMap = request.getParameterMap();
        // Drain the input stream once and keep the raw bytes for later re-reads.
        bytes = IOUtils.toByteArray(request.getInputStream());
    }

    /**
     * Returns the cached request body decoded as UTF-8.
     *
     * @return the request body as a string (empty if the body was empty)
     */
    public String getRequestBodyParame() {
        return new String(bytes, StandardCharsets.UTF_8);
    }

    /**
     * Returns the request that was passed to the constructor.
     *
     * @return the original request
     */
    public HttpServletRequest getOrgRequest() {
        return orgRequest;
    }

    /**
     * Static helper that unwraps one level of this wrapper.
     *
     * @param req any request
     * @return the request held by {@code req} if it is an instance of this wrapper,
     *         otherwise {@code req} itself
     */
    public static HttpServletRequest getOrgRequest(HttpServletRequest req) {
        if (req instanceof HttpServletRequestWrapper) {
            return ((HttpServletRequestWrapper) req).getOrgRequest();
        }
        return req;
    }

    /**
     * Returns a new stream over the cached body, so the body can be read repeatedly
     * (avoids "RequestBody is missing" errors on POST requests).
     *
     * <p>The cached bytes are served as-is. Round-tripping them through a {@code String}
     * with the platform charset would corrupt binary or non-default-charset payloads.
     *
     * @return a fresh stream positioned at the start of the cached body
     * @throws IOException declared by the overridden method; not thrown here
     * @see javax.servlet.ServletRequestWrapper#getInputStream()
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new BufferedServletInputStream(this.bytes);
    }

    /**
     * {@link ServletInputStream} backed by an in-memory byte array.
     */
    class BufferedServletInputStream extends ServletInputStream {
        private final ByteArrayInputStream inputStream;

        /**
         * @param buffer the bytes to serve; the array is read, never modified
         */
        public BufferedServletInputStream(byte[] buffer) {
            this.inputStream = new ByteArrayInputStream(buffer);
        }

        @Override
        public int available() throws IOException {
            return inputStream.available();
        }

        @Override
        public int read() throws IOException {
            return inputStream.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return inputStream.read(b, off, len);
        }

        /**
         * @return {@code true} once every cached byte has been consumed
         */
        @Override
        public boolean isFinished() {
            return inputStream.available() == 0;
        }

        /**
         * @return always {@code true}: the data is in memory, so a read never blocks
         */
        @Override
        public boolean isReady() {
            return true;
        }

        /**
         * Non-blocking reads are not supported; the listener is ignored.
         */
        @Override
        public void setReadListener(ReadListener listener) {
            // Intentionally a no-op.
        }
    }

    /**
     * Returns a new reader over the cached body.
     *
     * <p>The body is decoded with the request's declared character encoding, or UTF-8 when
     * none is declared, rather than with the platform default charset.
     *
     * @return a fresh reader positioned at the start of the cached body
     * @throws IOException if the declared character encoding is not supported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        InputStreamReader reader = (encoding == null)
                ? new InputStreamReader(this.getInputStream(), StandardCharsets.UTF_8)
                : new InputStreamReader(this.getInputStream(), encoding);
        return new BufferedReader(reader);
    }

    /**
     * Returns all values of a parameter after XSS filtering.
     *
     * @param parameter the parameter name
     * @return the filtered values, or {@code null} if the parameter is absent or has no values
     * @see javax.servlet.ServletRequestWrapper#getParameterValues(String)
     */
    @Override
    public String[] getParameterValues(String parameter) {
        String[] values = parameterMap.get(parameter);
        if (values == null || values.length == 0) {
            return null;
        }
        // Filter into a copy so the captured parameter map is left untouched.
        String[] encodedValues = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            encodedValues[i] = cleanXSS(values[i]);
        }
        return encodedValues;
    }

    /**
     * @return the names of the parameters captured at construction time
     */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(parameterMap.keySet());
    }

    /**
     * Returns the first value of a parameter after XSS filtering.
     *
     * @param parameter the parameter name
     * @return the filtered first value, or {@code null} if the parameter is absent,
     *         has no values, or its first value is {@code null}
     */
    @Override
    public String getParameter(String parameter) {
        String[] results = parameterMap.get(parameter);
        if (results == null || results.length == 0) {
            return null;
        }
        return cleanXSS(results[0]);
    }

    /**
     * Registers an extra header. For {@link #getHeader(String)} it takes precedence over a
     * header of the same name (ignoring case) on the wrapped request.
     *
     * @param name  the header name; must not be {@code null}
     * @param value the header value
     * @throws NullPointerException if {@code name} is {@code null}
     */
    public void addHeader(String name, String value) {
        Objects.requireNonNull(name, "header name must not be null");
        headers.put(name, value);
    }

    /**
     * Returns the XSS-filtered value of a header, preferring a value registered through
     * {@link #addHeader(String, String)} over the wrapped request's own header.
     *
     * @param name the header name (case-insensitive)
     * @return the filtered header value, or {@code null} if there is none
     */
    @Override
    public String getHeader(String name) {
        // The case-insensitive map rejects null keys, hence the explicit null guard.
        String value = (name != null && headers.containsKey(name))
                ? headers.get(name)
                : super.getHeader(name);
        return cleanXSS(value);
    }

    /**
     * Returns the wrapped request's header names followed by the names of any registered
     * extra headers that the wrapped request does not already carry (compared ignoring case).
     *
     * @return the combined header names, without duplicates introduced by extra headers
     */
    @Override
    public Enumeration<String> getHeaderNames() {
        // The servlet API allows a container to return null here.
        Enumeration<String> original = super.getHeaderNames();
        List<String> names = (original == null)
                ? new ArrayList<String>()
                : Collections.list(original);
        for (String custom : headers.keySet()) {
            if (!containsIgnoreCase(names, custom)) {
                names.add(custom);
            }
        }
        return Collections.enumeration(names);
    }

    /**
     * Returns the wrapped request's values for a header, followed by the value registered
     * through {@link #addHeader(String, String)}, if any. Values are returned unfiltered.
     *
     * @param name the header name (case-insensitive)
     * @return the header values; empty if there are none
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        Enumeration<String> original = super.getHeaders(name);
        List<String> list = (original == null)
                ? new ArrayList<String>()
                : Collections.list(original);
        if (name != null && headers.containsKey(name)) {
            list.add(headers.get(name));
        }
        return Collections.enumeration(list);
    }

    /** Returns whether {@code names} contains {@code candidate}, ignoring case. */
    private static boolean containsIgnoreCase(List<String> names, String candidate) {
        for (String existing : names) {
            if (candidate.equalsIgnoreCase(existing)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Minimal XSS filter: removes {@code eval(...)} calls, blanks quoted {@code javascript:}
     * URIs and strips every occurrence of the literal text {@code script}.
     *
     * <p>HTML-entity escaping of {@code <}, {@code >}, parentheses and single quotes is
     * currently not applied (it was present only as commented-out code).
     *
     * @param value the raw value, may be {@code null}
     * @return the filtered value, or {@code null} if {@code value} is {@code null}
     */
    private String cleanXSS(String value) {
        if (value == null) {
            return null;
        }
        String cleaned = EVAL_CALL.matcher(value).replaceAll("");
        cleaned = QUOTED_JAVASCRIPT_URI.matcher(cleaned).replaceAll("\"\"");
        // "script" contains no regex metacharacters, so a literal replace is equivalent.
        return cleaned.replace("script", "");
    }

}